package start

import (
	"context"
	"golang.org/x/sync/semaphore"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sort"
	config2 "src/config"
	"strings"
)

// Proxy describes one gateway route. Requests matched by ProxyPath are
// forwarded to Protocol://Host:Port, with the first occurrence of ProxyPath
// in the request path replaced by Path.
//
// NOTE(review): StartGateway loads these via viper's UnmarshalKey, which (as
// far as I know) decodes through mapstructure and ignores `json` tags,
// matching field names case-insensitively instead — confirm before relying
// on the tag names.
type Proxy struct {
	// Protocol is the upstream URL scheme (original note: defaults to http).
	// StartGateway fills an empty value from gateway.proxyProtocol.
	Protocol string `json:"protocol"`
	// Sort is the route priority; larger values rank higher.
	Sort int `json:"sort"`
	// Host is the upstream domain name or IP address.
	Host string `json:"host"`
	// Port is the upstream port (original note: defaults to 80).
	// StartGateway fills an empty value from gateway.proxyPort.
	Port string `json:"port"`
	// ProxyPath is the path fragment that selects this route.
	ProxyPath string `json:"proxyPath"`
	// Path is the target path; it replaces ProxyPath in the forwarded path.
	Path string `json:"path"`
}

// GatewayService is the gateway's http.Handler; ServeHTTP consults its
// routing table to pick an upstream for each request.
type GatewayService struct {
	// Proxy maps each route's ProxyPath to its configuration (StartGateway
	// builds it keyed by ProxyPath).
	Proxy map[string]Proxy `yaml:"proxy"`
}
// Transport is an http.RoundTripper that rewrites the request path according
// to Handle and then delegates to the embedded RoundTripper.
type Transport struct {
	// Handle is the route whose ProxyPath/Path drive the path rewrite.
	Handle *Proxy
	// RoundTripper performs the actual request; nil means http.DefaultTransport.
	http.RoundTripper
}

// RoundTrip implements http.RoundTripper. It replaces the first occurrence of
// Handle.ProxyPath in the request path with Handle.Path and then delegates to
// the embedded RoundTripper, or to http.DefaultTransport when that is nil.
//
// The caller's request is never modified: the rewrite is applied to a copy,
// as the http.RoundTripper contract requires. A nil Handle forwards the
// request unchanged.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Choose the delegate in a local variable rather than assigning to
	// t.RoundTripper: a RoundTripper must be safe for concurrent use, and
	// writing the field here would race if t were shared between requests.
	next := t.RoundTripper
	if next == nil {
		next = http.DefaultTransport
	}
	if t.Handle == nil {
		return next.RoundTrip(req)
	}

	// Rewrite a copy so the original request and its URL stay intact.
	req = cloneRequest(req)
	req.URL.Path = strings.Replace(req.URL.Path, t.Handle.ProxyPath, t.Handle.Path, 1)
	if req.URL.RawPath != "" {
		// Keep the encoded form in step with Path so escapes such as %2F
		// survive. url.URL.EscapedPath only uses RawPath when it is a valid
		// encoding of Path, so any mismatch falls back to re-encoding Path.
		req.URL.RawPath = strings.Replace(req.URL.RawPath, t.Handle.ProxyPath, t.Handle.Path, 1)
	}
	return next.RoundTrip(req)
}

// cloneRequest makes a shallow copy of r whose URL is copied as well, so the
// copy's URL can be edited without affecting r. Header, Body and the other
// reference-typed fields remain shared with r; deepen the copy here if a
// caller ever needs to mutate them.
func cloneRequest(r *http.Request) *http.Request {
	dup := *r
	dup.URL = cloneURL(r.URL)
	return &dup
}

// cloneURL returns a freshly allocated copy of *u (nil in, nil out). The
// *url.Userinfo, if present, is duplicated as well, so the result does not
// share its user information with u.
func cloneURL(u *url.URL) *url.URL {
	if u == nil {
		return nil
	}
	dup := *u
	if info := u.User; info != nil {
		userCopy := *info
		dup.User = &userCopy
	}
	return &dup
}

// HttpBingfa is a request concurrency limiter ("bingfa" = concurrency).
// sem is intended to cap the number of requests handled at once; the
// Acquire/Release calls in ServeHTTP are currently commented out.
//
// NOTE(review): ctx is not read anywhere in the visible code. Storing a
// context in a struct (and as *context.Context, a pointer to an interface)
// is discouraged by the context package docs — consider passing the request
// context to sem.Acquire instead.
type HttpBingfa struct {
	ctx *context.Context
	sem *semaphore.Weighted
}

// HttpBingfaV is the process-wide concurrency limiter. It stays nil until
// bingfaInit runs.
var HttpBingfaV *HttpBingfa

// bingfaInit creates the global limiter, allowing at most 100 concurrent
// holders of its semaphore, and stores it in HttpBingfaV.
func bingfaInit() {
	const maxInFlight = 100
	background := context.Background()
	HttpBingfaV = &HttpBingfa{
		ctx: &background,
		sem: semaphore.NewWeighted(maxInFlight),
	}
}
// ServeHTTP implements http.Handler. It picks the best route whose ProxyPath
// occurs in the request path and reverse-proxies the request to that route's
// upstream (Protocol://Host:Port). Transport.RoundTrip performs the
// ProxyPath -> Path rewrite.
//
// Route selection is deterministic:
//   - routes with an empty Host are skipped;
//   - among the rest, the highest Sort wins, then the longest ProxyPath, then
//     the smallest map key.
//
// The outcome therefore no longer depends on Go's randomized map iteration
// order.
//
// Only the request path is matched, not the full request URI. A ProxyPath
// that appears only in the query string therefore selects nothing, which
// agrees with RoundTrip rewriting only the path.
//
// Responses:
//   - 404 when no route matches;
//   - 502 when the route's upstream URL cannot be parsed;
//   - otherwise the request is handed to httputil.ReverseProxy.
func (s *GatewayService) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Concurrency limiting via HttpBingfaV is disabled (bingfaInit is not
	// called from StartGateway). NOTE(review): if re-enabled, defer Release
	// only after Acquire succeeds — releasing weight that was never acquired
	// makes the semaphore panic.
	//defer HttpBingfaV.sem.Release(1)
	//err := HttpBingfaV.sem.Acquire(context.Background(), 1)
	//if err != nil {
	//	http.Error(w, err.Error(), http.StatusInternalServerError)
	//	return
	//}

	// outranks reports whether route a (stored under key ka) should be
	// preferred over route b (stored under key kb).
	outranks := func(a Proxy, ka string, b Proxy, kb string) bool {
		if a.Sort != b.Sort {
			return a.Sort > b.Sort
		}
		if len(a.ProxyPath) != len(b.ProxyPath) {
			return len(a.ProxyPath) > len(b.ProxyPath)
		}
		return ka < kb
	}

	var (
		route    Proxy
		routeKey string
		found    bool
	)
	for key, candidate := range s.Proxy {
		if candidate.Host == "" || !strings.Contains(r.URL.Path, candidate.ProxyPath) {
			continue
		}
		if !found || outranks(candidate, key, route, routeKey) {
			route, routeKey, found = candidate, key, true
		}
	}
	if !found {
		log.Printf("gateway: no route for %q", r.URL.Path)
		http.NotFound(w, r)
		return
	}

	// Protocol's documented default is http.
	scheme := route.Protocol
	if scheme == "" {
		scheme = "http"
	}
	remote, err := url.Parse(scheme + "://" + route.Host + ":" + route.Port)
	if err != nil {
		// A nil remote would make NewSingleHostReverseProxy panic.
		log.Printf("gateway: invalid upstream for route %q: %v", route.ProxyPath, err)
		http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
		return
	}

	proxy := httputil.NewSingleHostReverseProxy(remote)
	// route is a per-request copy, so the Transport never aliases s.Proxy.
	proxy.Transport = &Transport{Handle: &route}
	proxy.ServeHTTP(w, r)
}

// pairList orders Proxy entries by their Sort field by implementing
// sort.Interface.
type pairList []Proxy

// Len reports how many entries the list holds (sort.Interface).
func (pl pairList) Len() int {
	return len(pl)
}

// Swap exchanges the entries at positions a and b (sort.Interface).
func (pl pairList) Swap(a, b int) {
	tmp := pl[a]
	pl[a] = pl[b]
	pl[b] = tmp
}

// Less orders entries by ascending Sort (sort.Interface); wrap the list in
// sort.Reverse to get highest-priority-first ordering.
func (pl pairList) Less(a, b int) bool {
	return pl[b].Sort > pl[a].Sort
}
func StartGateway() {
	var err error
	log.Println("启动 gateway 服务 ")
	//bingfaInit()
	var proxy = make(map[string]Proxy)
	var gatewayProxyList pairList
	//配置代理数据格式转换成结构体
	err = config2.V().UnmarshalKey("gateway.proxy", &gatewayProxyList)
	for i, stu := range gatewayProxyList {
		if stu.Protocol == "" {
			gatewayProxyList[i].Protocol = config2.V().GetString("gateway.proxyProtocol")
		}
		if stu.Port == "" {
			gatewayProxyList[i].Port = config2.V().GetString("gateway.proxyPort")
		}
		if stu.Host == "" {
			gatewayProxyList[i].Host = config2.V().GetString("gateway.proxyHost")
		}
		//pairs = append(pairs, stu)
	}
	// 排序
	//sort.Sort(sort.Reverse(pairs)) // 逆序排序，如果要正序排序，去掉Reverse
	sort.Sort(sort.Reverse(gatewayProxyList))
	for _, p := range gatewayProxyList {
		proxy[p.ProxyPath] = p
	}
	// 注册被代理的服务器 (host， port)
	service := &GatewayService{
		Proxy: proxy,
	}
	gatewayAddr := config2.V().GetString("gateway.ip") + ":" + config2.V().GetString("gateway.port")
	log.Println("gateway 服务监听地址: ", gatewayAddr)
	err = http.ListenAndServe(gatewayAddr, service)
	if err != nil {
		log.Fatalln("ListenAndServe: ", err)
	}
}
